// NOTE: loss is accumulated to whatever init value it already has.
// The x dimension corresponds to num, and y dimension corresponds
// to the spatial dimension (typically 1 for global image classification).
// Each block corresponds to a THREADS_PER_BLOCK_X-by-1 subregion in
// the label matrix (n-by-spatial_dim)
//
// NOTE: the loss is a float pointer regardless of whether the input
// is double or float. This is because CUDA atomicAdd currently only
// supports types up to float.
// Forward pass of the (optionally weighted) multinomial logistic loss.
//
// Every in-range thread handles one (sample, spatial position) pair: it reads
// the label, looks up the probability stored for that class, computes
// -log(max(prob, LOG_THRESHOLD)), scales it by the matching entry of `weights`
// when that pointer is non-NULL, and stores the result in shared memory.
// Thread 0 of the block then sums the per-thread values and performs a single
// atomicAdd into *loss.
//
// Launch layout: threadIdx.x / blockIdx.x index the sample, blockIdx.y indexes
// the spatial position. blockDim.x must not exceed THREADS_PER_BLOCK_X, which
// is the size of the static shared buffer.
//
// Parameters:
//   prob        probabilities, read at
//               idx*(spatial_dim*prob_dim) + y*spatial_dim + idx_sp
//   label       class labels stored as T, read at idx*spatial_dim + idx_sp and
//               truncated to int (assumed to lie in [0, prob_dim); this is not
//               validated here)
//   weights     optional per-entry weights using the same indexing as prob;
//               pass NULL to disable weighting
//   num         number of samples (bound for the x index)
//   spatial_dim number of spatial positions (bound for blockIdx.y)
//   prob_dim    number of classes (only used in the offset computation)
//   loss        accumulator; the block total is added to the existing value
template <typename T>
__device__ void logistic_loss_forward(T *prob, T *label, T *weights, int num, int spatial_dim, int prob_dim, float *loss) {
  __shared__ T local_loss[THREADS_PER_BLOCK_X];

  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  int idx_sp = blockIdx.y;
  if (idx >= num || idx_sp >= spatial_dim) {
    // out of bounds, set local loss to zero so that we can safely accumulate later
    local_loss[threadIdx.x] = 0;
  } else {
    // Compute flat offsets in 64-bit so that large tensors do not overflow int.
    const long long sample = idx;
    int y = static_cast<int>(label[sample * spatial_dim + idx_sp]);
    const long long offset = sample * spatial_dim * prob_dim
                           + static_cast<long long>(y) * spatial_dim
                           + idx_sp;
    T the_prob = prob[offset];
    // Clamp before the log so that a zero probability does not yield infinity.
    T log_prob = -log(max(the_prob, static_cast<T>(LOG_THRESHOLD)));
    if (weights != NULL)
      log_prob *= weights[offset];
    local_loss[threadIdx.x] = log_prob;
  }

  // Make every thread's shared-memory write visible before thread 0 reads them.
  __syncthreads();
  if (0 == threadIdx.x) {
    // thread 0 does in-block accumulation
    // Only slots [0, blockDim.x) were written above; reading further would sum
    // uninitialized shared memory when the block is launched with fewer than
    // THREADS_PER_BLOCK_X threads.
    int n_slots = static_cast<int>(blockDim.x);
    if (n_slots > THREADS_PER_BLOCK_X)
      n_slots = THREADS_PER_BLOCK_X;

    T total_local_loss = 0;
    for (int i = 0; i < n_slots; ++i)
      total_local_loss += local_loss[i];
    // One atomic per block; the total is narrowed to float to match *loss.
    atomicAdd(loss, static_cast<float>(total_local_loss));
  }
}

// Kernel entry points with C linkage (unmangled symbol names). Each one simply
// instantiates the templated logistic_loss_forward for a concrete element type
// and forwards all arguments unchanged. The loss accumulator is a float
// pointer in both variants.
extern "C" {
  // Single-precision variant.
  __global__ void logistic_loss_forward_float(
      float *prob, float *label, float *weights,
      int num, int spatial_dim, int prob_dim, float *loss) {
    logistic_loss_forward<float>(
        prob, label, weights, num, spatial_dim, prob_dim, loss);
  }

  // Double-precision variant.
  __global__ void logistic_loss_forward_double(
      double *prob, double *label, double *weights,
      int num, int spatial_dim, int prob_dim, float *loss) {
    logistic_loss_forward<double>(
        prob, label, weights, num, spatial_dim, prob_dim, loss);
  }
}

// vim: ft=cuda
